{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Code to Evaluate: \n",
    "  - Proposed model\n",
    "  - single-task baselines\n",
    "  - fine-tuning baselines"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: set the dataset, model and checkpoint paths to match your environment before running this notebook.\n",
    "\n",
    "This notebook contains commands for CS->BDD->IDD and CS->IDD->BDD settings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import os\n",
    "import importlib\n",
    "import time\n",
    "\n",
    "from PIL import Image\n",
    "from argparse import ArgumentParser\n",
    "\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.transforms import Compose, CenterCrop, Normalize, Resize\n",
    "from torchvision.transforms import ToTensor, ToPILImage\n",
    "\n",
    "from dataset_custom import cityscapes, IDD, BDD # original label id gt annotations\n",
    "\n",
    "from erfnet_RA_parallel import Net as Net_RAP # proposed model\n",
    "from erfnet import ERFNet as ERFNet_ind # single-task models \n",
    "from erfnet_ftp1 import Net as ERFNet_ft1 # 1st stage FT/FE (Eg. CS->BDD)\n",
    "from erfnet_ftp2 import Net as ERFNet_ft2 # 2nd stage FT/FE (CS|BDD->IDD)\n",
    "\n",
    "from transform import Relabel, ToLabel, Colorize # modify IDD label ids if saving colour maps. otherwise its fine. \n",
    "from iouEval import iouEval, getColorEntry\n",
    "from torchsummary import summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Channel count and per-dataset class counts used throughout the notebook.\n",
    "NUM_CHANNELS = 3\n",
    "NUMC_city, NUMC_bdd, NUMC_idd = 20, 20, 27\n",
    "\n",
    "# Images: bilinear resize to 512x1024, then conversion to a tensor.\n",
    "input_transform = Compose([Resize([512, 1024], Image.BILINEAR), ToTensor()])\n",
    "image_transform = ToPILImage()\n",
    "\n",
    "# Label maps: nearest-neighbour resize so label ids are never interpolated.\n",
    "# Relabel(255, n - 1) presumably folds label id 255 into the last class index\n",
    "# of the dataset - confirm in transform.Relabel.\n",
    "target_transform_cityscapes = Compose(\n",
    "    [Resize([512, 1024], Image.NEAREST), ToLabel(), Relabel(255, NUMC_city - 1)]\n",
    ")\n",
    "target_transform_IDD = Compose(\n",
    "    [Resize([512, 1024], Image.NEAREST), ToLabel(), Relabel(255, NUMC_idd - 1)]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Given a dataset name, return its validation loader and class-weighted criterion.\n",
    "def criterion_fn(data_name='cityscapes'):\n",
    "    \"\"\"Build the validation DataLoader and weighted loss for one dataset.\n",
    "\n",
    "    Args:\n",
    "        data_name: one of 'cityscapes', 'IDD' or 'BDD' (case-sensitive).\n",
    "\n",
    "    Returns:\n",
    "        (loader_val, criterion): a DataLoader over the dataset's 'val' split\n",
    "        (batch_size=1, shuffle=False, 4 workers) and an nn.CrossEntropyLoss\n",
    "        whose per-class weight tensor has been moved to the GPU.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if data_name is not one of the supported names.\n",
    "    \"\"\"\n",
    "    supported = ('cityscapes', 'IDD', 'BDD')\n",
    "    if data_name not in supported:\n",
    "        # Without this check an unknown name fell through every branch and\n",
    "        # crashed later with UnboundLocalError on `dataset_val`.\n",
    "        raise ValueError('Unknown data_name {!r}; expected one of {}'.format(data_name, supported))\n",
    "\n",
    "    # Per-class loss weights: 27 entries for IDD, 20 each for BDD and Cityscapes.\n",
    "    weight_IDD = torch.tensor([3.235635601598852, 6.76221624390441, 9.458242359884549, 9.446818215454014, 9.947040673126763, 9.789672819856547, 9.476665808564432, 10.465565126694731, 9.59189547383129,\n",
    "                               7.637805282159825, 8.990899026692638, 9.26222234098628, 10.265657138809514, 9.386517631614392, 8.357391489170013, 9.910382864314824, 10.389977663948363, 8.997422571963602,\n",
    "                               10.418070541191673, 10.483262606962834, 9.511436923349441, 7.597725385711079, 6.1734896019878205, 9.787631041755187, 3.9178330193378708, 4.417448652936843, 10.313160683418731])\n",
    "\n",
    "    weight_BDD = torch.tensor([3.6525147483016243, 8.799815287822142, 4.781908267406055, 10.034828238618045, 9.5567865464289, 9.645099012085169, 10.315292989325766, 10.163473632969513, 4.791692009441432,\n",
    "                               9.556915153488912, 4.142994047786311, 10.246903827488143, 10.47145010979545, 6.006704177894196, 9.60620532303246, 9.964959813857726, 10.478333987902301, 10.468010534454706,\n",
    "                               10.440929141422366, 3.960822533003462])\n",
    "\n",
    "    weight_city = torch.tensor([2.8159904084894922, 6.9874672455551075, 3.7901719017455604, 9.94305485286704, 9.77037625072462, 9.511470001589007, 10.310780572569994, 10.025305236316246, 4.6341256102158805,\n",
    "                               9.561389195953845, 7.869695292372276, 9.518873463871952, 10.374050047877898, 6.662394711556909, 10.26054487392723, 10.28786101490449, 10.289883605859952, 10.405463349170795,\n",
    "                               10.138502340710136, 5.131658171724055])\n",
    "\n",
    "    # The last class of every dataset gets zero weight, so it never contributes to the loss.\n",
    "    weight_city[-1] = 0\n",
    "    weight_BDD[-1] = 0\n",
    "    weight_IDD[-1] = 0\n",
    "\n",
    "    # Dataset roots: edit these to match the local environment.\n",
    "    CS_datadir = '/ssd_scratch/cvit/prachigarg/cityscapes/'\n",
    "    BDD_datadir = '/ssd_scratch/cvit/prachigarg/bdd100k/seg/'\n",
    "    IDD_datadir = '/ssd_scratch/cvit/prachigarg/IDD_Segmentation/'\n",
    "\n",
    "    if data_name == 'cityscapes':\n",
    "        dataset_val = cityscapes(CS_datadir, input_transform,\n",
    "                                 target_transform_cityscapes, 'val')\n",
    "        weight = weight_city\n",
    "    elif data_name == 'IDD':\n",
    "        dataset_val = IDD(IDD_datadir, input_transform,\n",
    "                          target_transform_IDD, 'val')\n",
    "        weight = weight_IDD\n",
    "    else:  # 'BDD' - uses the same 20-class target transform as Cityscapes\n",
    "        dataset_val = BDD(BDD_datadir, input_transform,\n",
    "                          target_transform_cityscapes, 'val')\n",
    "        weight = weight_BDD\n",
    "\n",
    "    loader_val = DataLoader(dataset_val, num_workers=4,\n",
    "                            batch_size=1, shuffle=False)\n",
    "\n",
    "    criterion = nn.CrossEntropyLoss(weight=weight.cuda())\n",
    "\n",
    "    return loader_val, criterion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval(model, dataset_loader, criterion, task, num_classes):\n",
    "    \"\"\"Run `model` over a validation loader and return its IoU scores for one task.\n",
    "\n",
    "    NOTE: this function shadows Python's builtin `eval` for the rest of the\n",
    "    notebook; the name is kept because later cells call it.\n",
    "\n",
    "    Args:\n",
    "        model: network invoked as `model(inputs, task)`.\n",
    "        dataset_loader: iterable yielding `(images, labels, filename, filenameGt)`.\n",
    "        criterion: loss invoked as `criterion(outputs, targets[:, 0])`.\n",
    "        task: forwarded to the model and used to index `num_classes`.\n",
    "        num_classes: sequence of per-task class counts.\n",
    "\n",
    "    Returns:\n",
    "        `(iou_classes, iouVal)`: the two values returned by `iouEval.getIoU()`,\n",
    "        in swapped order. `iouVal` prints as a scalar tensor in this notebook;\n",
    "        `iou_classes` is presumably the per-class breakdown - confirm in iouEval.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    num_cls = num_classes[task]\n",
    "\n",
    "    # The second argument is the last class id of this task (presumably the\n",
    "    # index iouEval ignores - confirm in iouEval).\n",
    "    iouEvalVal = iouEval(num_cls, num_cls - 1)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for images, labels, _filename, _filenameGt in dataset_loader:\n",
    "            inputs = images.cuda()\n",
    "            targets = labels.cuda()\n",
    "\n",
    "            outputs = model(inputs, task)\n",
    "\n",
    "            # The loss is still evaluated so a criterion/target mismatch fails\n",
    "            # here, but its value is not part of the return. The former\n",
    "            # unused average raised ZeroDivisionError on an empty loader.\n",
    "            criterion(outputs, targets[:, 0])\n",
    "\n",
    "            # Arg-max over dim 1 of the outputs, kept as a singleton dim for addBatch.\n",
    "            iouEvalVal.addBatch(outputs.max(1)[1].unsqueeze(1).data, targets.data)\n",
    "\n",
    "    iouVal, iou_classes = iouEvalVal.getIoU()\n",
    "\n",
    "    return iou_classes, iouVal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/ssd_scratch/cvit/prachigarg/cityscapes/leftImg8bit/val\n",
      "/ssd_scratch/cvit/prachigarg/bdd100k/seg/images/val\n",
      "/ssd_scratch/cvit/prachigarg/IDD_Segmentation/leftImg8bit/val\n"
     ]
    }
   ],
   "source": [
    "# One validation loader + weighted criterion per dataset (order kept: CS, BDD, IDD).\n",
    "loader_val_CS, criterion_CS = criterion_fn(data_name='cityscapes')\n",
    "loader_val_BDD, criterion_BDD = criterion_fn(data_name='BDD')\n",
    "loader_val_IDD, criterion_IDD = criterion_fn(data_name='IDD')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### KLD-RAPFT-dlr2 - PROPOSED RESULTS "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hi, inside erfnet_RA_parallel 0 1\n",
      "tensor(0.7182, dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# STEP 1 CS (with the DS-RAP and DS-BN)\n",
    "# Build the proposed model with a single 20-class head and wrap it for multi-GPU use.\n",
    "model_step1 = torch.nn.DataParallel(Net_RAP([20], 1, 0)).cuda()\n",
    "\n",
    "# Restore the step-1 weights; the checkpoint stores them under 'state_dict'.\n",
    "saved_model = torch.load('/home2/prachigarg/temp_MDILSS/checkpoints/RAP_FT_KLD/step1cs/model_best_cityscapes_erfnet_RA_parallel_150_6RAP_FT_step1.pth.tar')\n",
    "model_step1.load_state_dict(saved_model['state_dict'])\n",
    "\n",
    "# Task 0 evaluated on the Cityscapes validation set.\n",
    "iou_classes_step1_CS, val_acc_step1_CS = eval(\n",
    "    model_step1, loader_val_CS, criterion_CS, task=0, num_classes=[20])\n",
    "print(val_acc_step1_CS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hi, inside erfnet_RA_parallel 1 2\n",
      "tensor(0.6521, dtype=torch.float64)\n",
      "tensor(0.5573, dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# STEP 2 CS->BDD\n",
    "# Two 20-class heads: task 0 is evaluated on Cityscapes, task 1 on BDD.\n",
    "model_step2 = torch.nn.DataParallel(Net_RAP([20, 20], 2, 1)).cuda()\n",
    "\n",
    "saved_model = torch.load('/home2/prachigarg/temp_MDILSS/checkpoints/RAP_FT_KLD/step2bdd/model_best_BDD_erfnet_RA_parallel_150_6RAP_FT_dlr2-5e-6-KLD-ouput-1e-1_step2.pth.tar')\n",
    "model_step2.load_state_dict(saved_model['state_dict'])\n",
    "\n",
    "# Score on the earlier domain (Cityscapes) after learning BDD.\n",
    "iou_classes_step2_CS, val_acc_step2_CS = eval(\n",
    "    model_step2, loader_val_CS, criterion_CS, task=0, num_classes=[20, 20])\n",
    "print(val_acc_step2_CS)\n",
    "\n",
    "# Score on the newly added domain (BDD).\n",
    "iou_classes_step2_BDD, val_acc_step2_BDD = eval(\n",
    "    model_step2, loader_val_BDD, criterion_BDD, task=1, num_classes=[20, 20])\n",
    "print(val_acc_step2_BDD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hi, inside erfnet_RA_parallel 1 2\n",
      "tensor(0.6458, dtype=torch.float64)\n",
      "tensor(0.5911, dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# STEP 2 CS->IDD\n",
    "# Heads: task 0 has 20 classes (Cityscapes), task 1 has 27 classes (IDD).\n",
    "# Note: this rebinds `model_step2` and `*_step2_CS` from the CS->BDD cell.\n",
    "model_step2 = torch.nn.DataParallel(Net_RAP([20, 27], 2, 1)).cuda()\n",
    "\n",
    "saved_model = torch.load('/home2/prachigarg/temp_MDILSS/checkpoints/RAP_FT_KLD/step2idd/model_best_IDD_erfnet_RA_parallel_150_6RAP_FT_dlr2-5e-6-KLD-ouput-1e-1-IDD_step2.pth.tar')\n",
    "model_step2.load_state_dict(saved_model['state_dict'])\n",
    "\n",
    "# Score on the earlier domain (Cityscapes) after learning IDD.\n",
    "iou_classes_step2_CS, val_acc_step2_CS = eval(\n",
    "    model_step2, loader_val_CS, criterion_CS, task=0, num_classes=[20, 27])\n",
    "print(val_acc_step2_CS)\n",
    "\n",
    "# Score on the newly added domain (IDD).\n",
    "iou_classes_step2_IDD, val_acc_step2_IDD = eval(\n",
    "    model_step2, loader_val_IDD, criterion_IDD, task=1, num_classes=[20, 27])\n",
    "print(val_acc_step2_IDD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hi, inside erfnet_RA_parallel 2 3\n",
      "tensor(0.5919, dtype=torch.float64)\n",
      "tensor(0.4966, dtype=torch.float64)\n",
      "tensor(0.5916, dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# STEP 3 CS|BDD->IDD\n",
    "# Three heads in learning order: Cityscapes (20), BDD (20), IDD (27).\n",
    "model_step3 = torch.nn.DataParallel(Net_RAP([20, 20, 27], 3, 2)).cuda()\n",
    "\n",
    "saved_model = torch.load('/home2/prachigarg/temp_MDILSS/checkpoints/RAP_FT_KLD/CS1_BDD2_IDD3/model_best_IDD_erfnet_RA_parallel_150_6RAP_FT_dlr2-5e-6-KLD-output-1e-1_step3.pth.tar')\n",
    "model_step3.load_state_dict(saved_model['state_dict'])\n",
    "\n",
    "# Task 0: Cityscapes.\n",
    "iou_classes_step3_CS, val_acc_step3_CS = eval(\n",
    "    model_step3, loader_val_CS, criterion_CS, task=0, num_classes=[20, 20, 27])\n",
    "print(val_acc_step3_CS)\n",
    "\n",
    "# Task 1: BDD.\n",
    "iou_classes_step3_BDD, val_acc_step3_BDD = eval(\n",
    "    model_step3, loader_val_BDD, criterion_BDD, task=1, num_classes=[20, 20, 27])\n",
    "print(val_acc_step3_BDD)\n",
    "\n",
    "# Task 2: IDD.\n",
    "iou_classes_step3_IDD, val_acc_step3_IDD = eval(\n",
    "    model_step3, loader_val_IDD, criterion_IDD, task=2, num_classes=[20, 20, 27])\n",
    "print(val_acc_step3_IDD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hi, inside erfnet_RA_parallel 2 3\n",
      "tensor(0.6255, dtype=torch.float64)\n",
      "tensor(0.5385, dtype=torch.float64)\n",
      "tensor(0.5590, dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# STEP 3 CS|IDD->BDD\n",
    "# Three heads in learning order: Cityscapes (20), IDD (27), BDD (20).\n",
    "# Note: this rebinds `model_step3` and the `*_step3_*` results of the previous cell.\n",
    "model_step3 = torch.nn.DataParallel(Net_RAP([20, 27, 20], 3, 2)).cuda()\n",
    "\n",
    "saved_model = torch.load('/home2/prachigarg/temp_MDILSS/checkpoints/RAP_FT_KLD/CS1_IDD2_BDD3/checkpoint_BDD_erfnet_RA_parallel_150_6OURS-CS1_IDD2_BDD3_step3.pth.tar')\n",
    "model_step3.load_state_dict(saved_model['state_dict'])\n",
    "\n",
    "# Task 0: Cityscapes.\n",
    "iou_classes_step3_CS, val_acc_step3_CS = eval(\n",
    "    model_step3, loader_val_CS, criterion_CS, task=0, num_classes=[20, 27, 20])\n",
    "print(val_acc_step3_CS)\n",
    "\n",
    "# Task 1: IDD.\n",
    "iou_classes_step3_IDD, val_acc_step3_IDD = eval(\n",
    "    model_step3, loader_val_IDD, criterion_IDD, task=1, num_classes=[20, 27, 20])\n",
    "print(val_acc_step3_IDD)\n",
    "\n",
    "# Task 2: BDD.\n",
    "iou_classes_step3_BDD, val_acc_step3_BDD = eval(\n",
    "    model_step3, loader_val_BDD, criterion_BDD, task=2, num_classes=[20, 27, 20])\n",
    "print(val_acc_step3_BDD)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Single-Task Baselines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.7255, dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# STEP 1, single task\n",
    "# Cityscapes\n",
    "# Independent 20-class ERFNet, wrapped for multi-GPU use.\n",
    "model_step1 = torch.nn.DataParallel(ERFNet_ind(20)).cuda()\n",
    "\n",
    "saved_model = torch.load('/home2/prachigarg/temp_MDILSS/checkpoints/single-task/model_best_cityscapes_prenc.pth.tar')\n",
    "model_step1.load_state_dict(saved_model['state_dict'])\n",
    "\n",
    "# Single-task models are evaluated as task 0 of a one-entry class list.\n",
    "iou_classes_ST_cs, val_acc_ST_cs = eval(\n",
    "    model_step1, loader_val_CS, criterion_CS, task=0, num_classes=[20])\n",
    "print(val_acc_ST_cs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.5410, dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# BDD100k\n",
    "# Build a fresh 20-class model in this cell instead of reusing whatever\n",
    "# `model_step1` currently holds: after the IDD cell it is a 27-class network,\n",
    "# and loading the BDD checkpoint into it would fail.\n",
    "model_step1 = torch.nn.DataParallel(ERFNet_ind(20)).cuda()\n",
    "saved_model = torch.load('/home2/prachigarg/temp_MDILSS/checkpoints/single-task/checkpoint_BDD_prenc.pth.tar')\n",
    "model_step1.load_state_dict(saved_model['state_dict'])\n",
    "\n",
    "# Single-task model evaluated as task 0 on the BDD validation set.\n",
    "iou_classes_ST_bdd, val_acc_ST_bdd = eval(model_step1, loader_val_BDD, criterion_BDD, 0, [20])\n",
    "print(val_acc_ST_bdd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.6197, dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# IDD\n",
    "# Independent 27-class ERFNet; rebinds `model_step1` from the previous single-task cells.\n",
    "model_step1 = torch.nn.DataParallel(ERFNet_ind(27)).cuda()\n",
    "\n",
    "saved_model = torch.load('/home2/prachigarg/temp_MDILSS/checkpoints/single-task/checkpoint_IDD_prenc.pth.tar')\n",
    "model_step1.load_state_dict(saved_model['state_dict'])\n",
    "\n",
    "# Single-task model evaluated as task 0 on the IDD validation set.\n",
    "iou_classes_ST_idd, val_acc_ST_idd = eval(\n",
    "    model_step1, loader_val_IDD, criterion_IDD, task=0, num_classes=[27])\n",
    "print(val_acc_ST_idd)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Finetune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_finetune(model, dataset_loader, criterion, task, num_classes):\n",
    "    \"\"\"Evaluate a two-decoder fine-tuning model on one task and return its IoU scores.\n",
    "\n",
    "    Args:\n",
    "        model: network invoked as `model(inputs, decoder_old=..., decoder_new=...)`.\n",
    "        dataset_loader: iterable yielding `(images, labels, filename, filenameGt)`.\n",
    "        criterion: loss invoked as `criterion(outputs, targets[:, 0])`.\n",
    "        task: 0 selects the old decoder, 1 selects the new decoder.\n",
    "        num_classes: sequence of per-task class counts, indexed by `task`.\n",
    "\n",
    "    Returns:\n",
    "        `(iou_classes, iouVal)`: the two values returned by `iouEval.getIoU()`,\n",
    "        in swapped order.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `task` is neither 0 nor 1.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    num_cls = num_classes[task]\n",
    "\n",
    "    # Resolve the decoder selection once, before any batch is processed.\n",
    "    # Previously an unsupported task left `outputs` unbound and crashed with\n",
    "    # UnboundLocalError inside the loop.\n",
    "    if task == 0:\n",
    "        decoder_flags = {'decoder_old': True, 'decoder_new': False}\n",
    "    elif task == 1:\n",
    "        decoder_flags = {'decoder_old': False, 'decoder_new': True}\n",
    "    else:\n",
    "        raise ValueError('eval_finetune supports task 0 or 1, got {!r}'.format(task))\n",
    "\n",
    "    # The second argument is the last class id of this task (presumably the\n",
    "    # index iouEval ignores - confirm in iouEval).\n",
    "    iouEvalVal = iouEval(num_cls, num_cls - 1)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for images, labels, _filename, _filenameGt in dataset_loader:\n",
    "            inputs = images.cuda()\n",
    "            targets = labels.cuda()\n",
    "\n",
    "            outputs = model(inputs, **decoder_flags)\n",
    "\n",
    "            # The loss is still evaluated so a criterion/target mismatch fails\n",
    "            # here, but its value is not part of the return. The former\n",
    "            # unused average raised ZeroDivisionError on an empty loader.\n",
    "            criterion(outputs, targets[:, 0])\n",
    "\n",
    "            iouEvalVal.addBatch(outputs.max(1)[1].unsqueeze(1).data, targets.data)\n",
    "\n",
    "    iouVal, iou_classes = iouEvalVal.getIoU()\n",
    "\n",
    "    print('check val fn, loss, acc: ', iouVal) \n",
    "\n",
    "    return iou_classes, iouVal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_finetune3(model, dataset_loader, criterion, task, num_classes=[20, 20, 27]):\n",
    "    \"\"\"Evaluate a three-decoder fine-tuning model on one task and return its IoU scores.\n",
    "\n",
    "    Args:\n",
    "        model: network invoked as\n",
    "            `model(inputs, decoder_old1=..., decoder_old2=..., decoder_new=...)`.\n",
    "        dataset_loader: iterable yielding `(images, labels, filename, filenameGt)`.\n",
    "        criterion: loss invoked as `criterion(outputs, targets[:, 0])`.\n",
    "        task: 0 selects `decoder_old1`, 1 selects `decoder_old2`, 2 selects `decoder_new`.\n",
    "        num_classes: sequence of per-task class counts, indexed by `task`.\n",
    "            The list default is only read, never mutated.\n",
    "\n",
    "    Returns:\n",
    "        `(iou_classes, iouVal)`: the two values returned by `iouEval.getIoU()`,\n",
    "        in swapped order.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `task` is not 0, 1 or 2.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    num_cls = num_classes[task]\n",
    "\n",
    "    # Resolve the decoder selection once, before any batch is processed.\n",
    "    # Previously an unsupported task left `outputs` unbound and crashed with\n",
    "    # UnboundLocalError inside the loop.\n",
    "    if task == 0:\n",
    "        decoder_flags = {'decoder_old1': True, 'decoder_old2': False, 'decoder_new': False}\n",
    "    elif task == 1:\n",
    "        decoder_flags = {'decoder_old1': False, 'decoder_old2': True, 'decoder_new': False}\n",
    "    elif task == 2:\n",
    "        decoder_flags = {'decoder_old1': False, 'decoder_old2': False, 'decoder_new': True}\n",
    "    else:\n",
    "        raise ValueError('eval_finetune3 supports task 0, 1 or 2, got {!r}'.format(task))\n",
    "\n",
    "    # The second argument is the last class id of this task (presumably the\n",
    "    # index iouEval ignores - confirm in iouEval).\n",
    "    iouEvalVal = iouEval(num_cls, num_cls - 1)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for images, labels, _filename, _filenameGt in dataset_loader:\n",
    "            inputs = images.cuda()\n",
    "            targets = labels.cuda()\n",
    "\n",
    "            outputs = model(inputs, **decoder_flags)\n",
    "\n",
    "            # The loss is still evaluated so a criterion/target mismatch fails\n",
    "            # here, but its value is not part of the return. The former\n",
    "            # unused average raised ZeroDivisionError on an empty loader.\n",
    "            criterion(outputs, targets[:, 0])\n",
    "\n",
    "            iouEvalVal.addBatch(outputs.max(1)[1].unsqueeze(1).data, targets.data)\n",
    "\n",
    "    iouVal, iou_classes = iouEvalVal.getIoU()\n",
    "\n",
    "    print('check val fn, loss, acc: ', iouVal) \n",
    "\n",
    "    return iou_classes, iouVal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "check val fn, loss, acc:  tensor(0.4005, dtype=torch.float64)\n",
      "tensor(0.4005, dtype=torch.float64)\n",
      "check val fn, loss, acc:  tensor(0.5274, dtype=torch.float64)\n",
      "tensor(0.5274, dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# Step 2: Finetune CS model on BDD\n",
    "# Two-decoder fine-tuning baseline: old head = Cityscapes (20), new head = BDD (20).\n",
    "model_step2 = torch.nn.DataParallel(\n",
    "    ERFNet_ft1(num_classes_old=20, num_classes_new=20)).cuda()\n",
    "\n",
    "saved_model = torch.load('/home2/prachigarg/temp_MDILSS/checkpoints/FineTune/checkpoint_erfnet_ftp1_150_6_Finetune-CStoBDD-final.pth.tar')\n",
    "model_step2.load_state_dict(saved_model['state_dict'])\n",
    "\n",
    "# Task 0 -> old decoder, scored on Cityscapes.\n",
    "iou_classes_step2_FT1_CS, val_acc_step2_FT1_CS = eval_finetune(\n",
    "    model_step2, loader_val_CS, criterion_CS, task=0, num_classes=[20, 20])\n",
    "print(val_acc_step2_FT1_CS)\n",
    "\n",
    "# Task 1 -> new decoder, scored on BDD.\n",
    "iou_classes_step2_FT1_BDD, val_acc_step2_FT1_BDD = eval_finetune(\n",
    "    model_step2, loader_val_BDD, criterion_BDD, task=1, num_classes=[20, 20])\n",
    "print(val_acc_step2_FT1_BDD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "check val fn, loss, acc:  tensor(0.3049, dtype=torch.float64)\n",
      "tensor(0.3049, dtype=torch.float64)\n",
      "check val fn, loss, acc:  tensor(0.3205, dtype=torch.float64)\n",
      "tensor(0.3205, dtype=torch.float64)\n",
      "check val fn, loss, acc:  tensor(0.6103, dtype=torch.float64)\n",
      "tensor(0.6103, dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# Step 3: Finetune CS|BDD model on IDD\n",
    "# Three-decoder fine-tuning baseline, class counts in learning order: 20, 20, 27.\n",
    "# Note: this rebinds `model_step3` and the `*_step3_*` results of the proposed-model cells.\n",
    "model_step3 = torch.nn.DataParallel(ERFNet_ft2(20, 20, 27)).cuda()\n",
    "\n",
    "saved_model = torch.load('/home2/prachigarg/temp_MDILSS/checkpoints/FineTune/model_best_erfnet_ftp2_150_6_Finetune-code-CSBDDtoIDD-FT.pth.tar')\n",
    "model_step3.load_state_dict(saved_model['state_dict'])\n",
    "\n",
    "# Task 0 -> decoder_old1, scored on Cityscapes.\n",
    "iou_classes_step3_CS, val_acc_step3_CS = eval_finetune3(\n",
    "    model_step3, loader_val_CS, criterion_CS, task=0, num_classes=[20, 20, 27])\n",
    "print(val_acc_step3_CS)\n",
    "\n",
    "# Task 1 -> decoder_old2, scored on BDD.\n",
    "iou_classes_step3_BDD, val_acc_step3_BDD = eval_finetune3(\n",
    "    model_step3, loader_val_BDD, criterion_BDD, task=1, num_classes=[20, 20, 27])\n",
    "print(val_acc_step3_BDD)\n",
    "\n",
    "# Task 2 -> decoder_new, scored on IDD.\n",
    "iou_classes_step3_IDD, val_acc_step3_IDD = eval_finetune3(\n",
    "    model_step3, loader_val_IDD, criterion_IDD, task=2, num_classes=[20, 20, 27])\n",
    "print(val_acc_step3_IDD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "check val fn, loss, acc:  tensor(0.3619, dtype=torch.float64)\n",
      "tensor(0.3619, dtype=torch.float64)\n",
      "check val fn, loss, acc:  tensor(0.2630, dtype=torch.float64)\n",
      "tensor(0.2630, dtype=torch.float64)\n",
      "check val fn, loss, acc:  tensor(0.5337, dtype=torch.float64)\n",
      "tensor(0.5337, dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# Step 3: Finetune CS|IDD model on BDD\n",
    "# Three-decoder fine-tuning baseline, class counts in learning order: 20, 27, 20.\n",
    "# Note: this rebinds `model_step3` and the `*_step3_*` results of the previous cell.\n",
    "model_step3 = torch.nn.DataParallel(ERFNet_ft2(20, 27, 20)).cuda()\n",
    "\n",
    "saved_model = torch.load('/home2/prachigarg/temp_MDILSS/checkpoints/FineTune/model_best_erfnet_ftp2_150_6_FT_CS1_IDD2_BDD3.pth.tar')\n",
    "model_step3.load_state_dict(saved_model['state_dict'])\n",
    "\n",
    "# Task 0 -> decoder_old1, scored on Cityscapes.\n",
    "iou_classes_step3_CS, val_acc_step3_CS = eval_finetune3(\n",
    "    model_step3, loader_val_CS, criterion_CS, task=0, num_classes=[20, 27, 20])\n",
    "print(val_acc_step3_CS)\n",
    "\n",
    "# Task 1 -> decoder_old2, scored on IDD.\n",
    "iou_classes_step3_IDD, val_acc_step3_IDD = eval_finetune3(\n",
    "    model_step3, loader_val_IDD, criterion_IDD, task=1, num_classes=[20, 27, 20])\n",
    "print(val_acc_step3_IDD)\n",
    "\n",
    "# Task 2 -> decoder_new, scored on BDD.\n",
    "iou_classes_step3_BDD, val_acc_step3_BDD = eval_finetune3(\n",
    "    model_step3, loader_val_BDD, criterion_BDD, task=2, num_classes=[20, 27, 20])\n",
    "print(val_acc_step3_BDD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
